{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5842113a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "# 一个形状为 (batch_size, seq_len, feature_dim) 的张量 x\n",
    "x = torch.randn(2, 3, 4)\n",
    "# 计算原始权重，形状为 (batch_size, seq_len, seq_len)\n",
    "raw_weights = torch.bmm(x, x.transpose(1, 2))\n",
    "# 对原始权重进行 softmax 归一化，形状为 (batch_size, seq_len, seq_len)\n",
    "attn_weights = F.softmax(raw_weights, dim=2)\n",
    "# 计算加权和，形状为 (batch_size, seq_len, feature_dim) \n",
    "attn_outputs = torch.bmm(attn_weights, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b7ec4b56",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 加权信息 : tensor([[[ 0.5676, -0.0132, -0.8214, -0.0548],\n",
      "         [ 0.5352, -0.1170, -0.5392, -0.0256],\n",
      "         [ 0.6141, -0.1343, -0.5587, -0.0331]],\n",
      "\n",
      "        [[ 0.5973, -0.2426, -0.3217, -0.0335],\n",
      "         [ 0.5996, -0.1914, -0.2840,  0.0152],\n",
      "         [ 0.6117, -0.2507, -0.3363, -0.0404]]], grad_fn=<BmmBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# 一个形状为 (batch_size, seq_len, feature_dim) 的张量 x\n",
    "x = torch.randn(2, 3, 4) # 形状 (batch_size, seq_len, feature_dim)\n",
    "# 定义线性层用于将 x 转换为 Q, K, V 向量\n",
    "linear_q = torch.nn.Linear(4, 4)\n",
    "linear_k = torch.nn.Linear(4, 4)\n",
    "linear_v = torch.nn.Linear(4, 4)\n",
    "# 通过线性层计算 Q, K, V\n",
    "Q = linear_q(x) # 形状 (batch_size, seq_len, feature_dim)\n",
    "K = linear_k(x) # 形状 (batch_size, seq_len, feature_dim)\n",
    "V = linear_v(x) # 形状 (batch_size, seq_len, feature_dim)\n",
    "# 计算 Q 和 K 的点积，作为相似度分数 , 也就是自注意力原始权重\n",
    "raw_weights = torch.bmm(Q, K.transpose(1, 2)) # 形状 (batch_size, seq_len, seq_len)\n",
    "# 将自注意力原始权重进行缩放\n",
    "scale_factor = K.size(-1) ** 0.5  # 这里是 4 ** 0.5\n",
    "scaled_weights = raw_weights / scale_factor # 形状 (batch_size, seq_len, seq_len)\n",
    "# 对缩放后的权重进行 softmax 归一化，得到注意力权重\n",
    "attn_weights = F.softmax(scaled_weights, dim=2) # 形状 (batch_size, seq_len, seq_len)\n",
    "# 将注意力权重应用于 V 向量，计算加权和，得到加权信息\n",
    "attn_outputs = torch.bmm(attn_weights, V) # 形状 (batch_size, seq_len, feature_dim)\n",
    "print(\" 加权信息 :\", attn_outputs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
